Classification of documents

This is the core of the project: classifying each document into a subcategory.

Once this is done, the expiry date can be easily computed.

Tech stack

The packages and tools used for this level are:

  • pandas: To read and manipulate dataframes

  • langchain: To generate the dynamic prompt and create the LLM instance

  • AWS Bedrock: To access LLMs, including Claude and Titan

  • FAISS: To store and retrieve the vectorized prompt chunks built from the category examples

  • RetrievalQA (a langchain chain): To generate the dynamic RAG prompt and query the LLM

Category examples

First, we made a dictionary containing examples of all the categories to act as context for the classification.

import pandas as pd

df = pd.read_csv("processed_Simplified.csv")

# Build a lookup of reference code -> data class, document types and examples
ref_examples_dict = {}
for index, row in df.iterrows():
    ref = row['Ref']
    ref_examples_dict[ref] = {}
    ref_examples_dict[ref]["data_class"] = row['Data Class']
    types = [x.strip() for x in row['Data Type'].split(',')]
    examples = [x.strip() for x in row['Examples'].split(',')]
    ref_examples_dict[ref]["data_type"] = types
    ref_examples_dict[ref]["example"] = examples
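
To make the structure concrete, a single entry would look like the following. The Ref code and values here are invented for illustration, not taken from the real CSV:

# Hypothetical entry for illustration only: keys mirror the loop above,
# values are invented, not from processed_Simplified.csv
example_entry = {
    "HR-01": {
        "data_class": "Human Resources",
        "data_type": ["Employment contract", "Offer letter"],
        "example": ["Signed employment agreement", "Offer of employment"],
    }
}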

LLM Prompt

There are too many categories to perform zero-shot classification. Hence, we pass the LLM a prompt containing the preprocessed document text and the categories with examples, and instruct it to classify the document.

def make_prompt(categories_with_examples):
    # Format every category with its data class, document types and examples
    prompt = "### Document Classification\n"
    prompt += "Classes, Doc Types and Examples:\n"
    for category, data in categories_with_examples.items():
        prompt += f"- {category}:\n"
        prompt += f"  {data['data_class']}\n"
        prompt += "  Doc Types:\n"
        for data_type in data['data_type']:
            prompt += f"  - {data_type}\n"
        prompt += "  Doc Examples:\n"
        for example in data['example']:
            prompt += f"  - {example}\n"

    prompt += "###\n"

    return prompt
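
Feeding the illustrative example_entry from above through make_prompt produces a block like the following:

# Illustrative only: the real prompt is built from ref_examples_dict
print(make_prompt(example_entry))
# ### Document Classification
# Classes, Doc Types and Examples:
# - HR-01:
#   Human Resources
#   Doc Types:
#   - Employment contract
#   - Offer letter
#   Doc Examples:
#   - Signed employment agreement
#   - Offer of employment
# ###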

Accessing LLMs

We use AWS Bedrock to access the LLMs. Several options are available - Claude, Titan, etc. We use Titan to generate the embeddings and Claude for the actual classification.

But first, we must create the LLM instance and embeddings.

import boto3
from langchain.embeddings import BedrockEmbeddings
from langchain.llms.bedrock import Bedrock

boto3_bedrock = boto3.client('bedrock-runtime', region_name='us-east-1')

# Claude (via Bedrock) is the LLM that performs the classification
llm = Bedrock(
    model_id="anthropic.claude-v2",
    client=boto3_bedrock,
    model_kwargs={
        "max_tokens_to_sample": 100,
        "temperature": 0
    }
)

# The Titan Embeddings model generates the embeddings for the vector store
bedrock_embeddings = BedrockEmbeddings(model_id="amazon.titan-embed-text-v1", client=boto3_bedrock)
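
As a quick sanity check that the Bedrock client and credentials are working, we can embed a short string and inspect the vector size; this is optional and not part of the pipeline, and the query string is just an example:

# Optional check: Titan Text Embeddings v1 typically returns a 1536-dimensional vector
vector = bedrock_embeddings.embed_query("sample document text")
print(len(vector))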

Next, we use the Bedrock embeddings along with the make_prompt function (applied to the ref_examples_dict dictionary) to create a FAISS wrapper around our prompt.

from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import FAISS

# Split the prompt text into 1000-character chunks with 100-character overlap
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=1000,
    chunk_overlap=100,
)

vectorstore_faiss = FAISS.from_documents(
    text_splitter.create_documents([make_prompt(ref_examples_dict)]),
    bedrock_embeddings,
)
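
To verify the index was built correctly, we can run a quick similarity search against it; the query string below is only an example:

# Optional check: retrieve the chunk most similar to an example query
docs = vectorstore_faiss.similarity_search("employment contract", k=1)
print(docs[0].page_content[:200])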

Dynamic prompting

Finally, we generate the dynamic RAG prompt that will perform the classification.

from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate

prompt_template = """

Human: Classify the given text into categories based on the PDF text you are given.
It can belong to only one category; use the examples and data types to detect the category.
Give only the category code and nothing else.\n

If you don't know the answer, return None.
<context>
{context}
</context>

Question: {question}

Assistant:"""

PROMPT = PromptTemplate(
    template=prompt_template, input_variables=["context", "question"]
)

qa = RetrievalQA.from_chain_type(
    llm=llm,
    chain_type="stuff",
    retriever=vectorstore_faiss.as_retriever(
        search_type="mmr"
    ),
    return_source_documents=True,
    chain_type_kwargs={"prompt": PROMPT}
)
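
A single call shows the shape of the response; the query text below is a placeholder standing in for a real document's extracted keywords, and the printed category code is illustrative:

# Illustrative single query against the chain
result = qa({"query": "employment agreement signed by employee and employer"})
print(result["result"])                  # the predicted category code, e.g. "HR-01" (illustrative)
print(len(result["source_documents"]))   # chunks retrieved from the FAISS index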

Classification

We can finally perform the classification, looping over the preprocessed documents in file_contents.

text_classification = {}

for file in file_contents.keys():
    try:
        # Query the chain with the document's important words
        query = file_contents[file]["imp_words"]
        result = qa({"query": query})
        # Map the file to whichever reference code appears in the answer
        for key in ref_examples_dict.keys():
            if key in result["result"]:
                text_classification[file] = key
    except Exception:
        # Skip documents that fail (e.g. empty text or a Bedrock error)
        continue
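
With every document mapped to a Ref code in text_classification, the expiry date mentioned at the start can be derived. A minimal sketch, assuming a hypothetical retention period (in years) per Ref code and a known document date:

from datetime import date

# Hypothetical retention periods per Ref code; the real values would come
# from the reference table, not from this example
retention_years = {"HR-01": 7, "FIN-02": 10}

def expiry_date(ref, document_date):
    # Add the retention period to the document date to get the expiry date
    years = retention_years.get(ref)
    if years is None:
        return None
    return document_date.replace(year=document_date.year + years)

print(expiry_date("HR-01", date(2020, 5, 1)))  # 2027-05-01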